{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 12、如何在PyTorch中训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random\n",
    "import os\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import torchvision as tv\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "device= cuda\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(\"device=\",device)\n",
    "# 搭建数据集\n",
    "train_data = tv.datasets.MNIST(\n",
    "    root=\"./mnist_dataset\",\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=tv.transforms.ToTensor(),\n",
    ")\n",
    "test_data = tv.datasets.MNIST(\n",
    "    root=\"./mnist_dataset\",\n",
    "    train=False,\n",
    "    download=True,\n",
    "    transform=tv.transforms.ToTensor(),\n",
    ")\n",
    "batch_size = 64\n",
    "train_iter = torch.utils.data.DataLoader(\n",
    "    dataset=train_data, shuffle=True, batch_size=batch_size, num_workers=0\n",
    ")\n",
    "test_iter = torch.utils.data.DataLoader(\n",
    "    dataset=test_data, shuffle=False, batch_size=batch_size, num_workers=0\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/30], 训练损失: 0.0188, 测试损失: 0.0090, 测试准确率: 0.8700\n",
      "Epoch [2/30], 训练损失: 0.0076, 测试损失: 0.0063, 测试准确率: 0.8956\n",
      "Epoch [3/30], 训练损失: 0.0061, 测试损失: 0.0054, 测试准确率: 0.9042\n",
      "Epoch [4/30], 训练损失: 0.0054, 测试损失: 0.0050, 测试准确率: 0.9110\n",
      "Epoch [5/30], 训练损失: 0.0051, 测试损失: 0.0047, 测试准确率: 0.9168\n",
      "Epoch [6/30], 训练损失: 0.0048, 测试损失: 0.0044, 测试准确率: 0.9212\n",
      "Epoch [7/30], 训练损失: 0.0045, 测试损失: 0.0042, 测试准确率: 0.9241\n",
      "Epoch [8/30], 训练损失: 0.0043, 测试损失: 0.0041, 测试准确率: 0.9276\n",
      "Epoch [9/30], 训练损失: 0.0041, 测试损失: 0.0039, 测试准确率: 0.9296\n",
      "Epoch [10/30], 训练损失: 0.0040, 测试损失: 0.0038, 测试准确率: 0.9324\n",
      "Epoch [11/30], 训练损失: 0.0038, 测试损失: 0.0036, 测试准确率: 0.9337\n",
      "Epoch [12/30], 训练损失: 0.0037, 测试损失: 0.0035, 测试准确率: 0.9365\n",
      "Epoch [13/30], 训练损失: 0.0036, 测试损失: 0.0034, 测试准确率: 0.9387\n",
      "Epoch [14/30], 训练损失: 0.0034, 测试损失: 0.0033, 测试准确率: 0.9403\n",
      "Epoch [15/30], 训练损失: 0.0033, 测试损失: 0.0032, 测试准确率: 0.9414\n",
      "Epoch [16/30], 训练损失: 0.0032, 测试损失: 0.0031, 测试准确率: 0.9424\n",
      "Epoch [17/30], 训练损失: 0.0031, 测试损失: 0.0030, 测试准确率: 0.9435\n",
      "Epoch [18/30], 训练损失: 0.0030, 测试损失: 0.0029, 测试准确率: 0.9448\n",
      "Epoch [19/30], 训练损失: 0.0029, 测试损失: 0.0029, 测试准确率: 0.9471\n",
      "Epoch [20/30], 训练损失: 0.0029, 测试损失: 0.0028, 测试准确率: 0.9465\n",
      "Epoch [21/30], 训练损失: 0.0028, 测试损失: 0.0027, 测试准确率: 0.9476\n",
      "Epoch [22/30], 训练损失: 0.0027, 测试损失: 0.0027, 测试准确率: 0.9499\n",
      "Epoch [23/30], 训练损失: 0.0026, 测试损失: 0.0026, 测试准确率: 0.9512\n",
      "Epoch [24/30], 训练损失: 0.0026, 测试损失: 0.0025, 测试准确率: 0.9523\n",
      "Epoch [25/30], 训练损失: 0.0025, 测试损失: 0.0025, 测试准确率: 0.9531\n",
      "Epoch [26/30], 训练损失: 0.0024, 测试损失: 0.0024, 测试准确率: 0.9532\n",
      "Epoch [27/30], 训练损失: 0.0024, 测试损失: 0.0024, 测试准确率: 0.9542\n",
      "Epoch [28/30], 训练损失: 0.0023, 测试损失: 0.0023, 测试准确率: 0.9556\n",
      "Epoch [29/30], 训练损失: 0.0023, 测试损失: 0.0023, 测试准确率: 0.9565\n",
      "Epoch [30/30], 训练损失: 0.0022, 测试损失: 0.0023, 测试准确率: 0.9579\n"
     ]
    }
   ],
   "source": [
    "loss_history = {}\n",
    "num_epoch = 30\n",
    "loss_function = nn.CrossEntropyLoss()\n",
    "net = torch.nn.Sequential(\n",
    "    nn.Flatten(),\n",
    "    nn.Linear(784, 256),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(256, 10),\n",
    ").to(device)\n",
    "\n",
    "loss_history[\"Train Loss\"] = []\n",
    "loss_history[\"Test Loss\"] = []\n",
    "loss_history[\"Test Acc\"] = []\n",
    "for epoch in range(30):\n",
    "    optimizer = torch.optim.SGD(net.parameters(), lr=0.01)\n",
    "    train_loss_epoch, item_cnt = 0, 0\n",
    "    for x, y in train_iter:\n",
    "        x, y = x.to(device), y.to(device)\n",
    "\n",
    "        pred = net(x) # 预测\n",
    "        los = loss_function(pred, y) # 计算梯度\n",
    "        los.backward()\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "        train_loss_epoch += los.item()\n",
    "        item_cnt += y.shape[0]\n",
    "\n",
    "    train_loss_epoch /= item_cnt\n",
    "    loss_history[\"Train Loss\"].append(train_loss_epoch)\n",
    "    with torch.no_grad():\n",
    "        test_loss_epoch, test_acc_epoch, item_cnt = 0, 0, 0\n",
    "        for x, y in test_iter:\n",
    "            x, y = x.to(device), y.to(device)\n",
    "            pred = net(x)\n",
    "            los = loss_function(pred, y)\n",
    "            test_loss_epoch += los.item()\n",
    "            test_acc_epoch += (pred.argmax(dim=1).long() == y.long()).sum().item()\n",
    "            item_cnt += y.shape[0]\n",
    "\n",
    "    test_loss_epoch /= item_cnt\n",
    "    loss_history[\"Test Loss\"].append(test_loss_epoch)\n",
    "    test_acc_epoch /= item_cnt\n",
    "    loss_history[\"Test Acc\"].append(test_acc_epoch)\n",
    "    print(\n",
    "        f\"Epoch [{epoch + 1}/{num_epoch}], 训练损失: {train_loss_epoch:.4f}, 测试损失: {test_loss_epoch:.4f}, 测试准确率: {test_acc_epoch:.4f}\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
